﻿#pragma kernel ScanInBucketExclusive
#pragma kernel ScanAddBucketResult

#define THREADS_PER_GROUP 512 // Ensure that this equals the 'threadsPerGroup' const in the host script. Must be an odd power of 2.

// Resources used by the kernels below.
// NOTE(review): the original note here said "these must be a multiple of
// THREADS_PER_GROUP" - presumably it refers to the element counts of the
// buffers; confirm against the host script.

// Read-only input: per-element values in ScanInBucketExclusive (indexed by
// dispatch thread id), per-group values in ScanAddBucketResult (indexed by
// group id).
StructuredBuffer<uint> _Input;
// Per-element output; written by ScanInBucketExclusive and incremented in
// place by ScanAddBucketResult.
RWStructuredBuffer<uint> _Result;
// One entry per thread group: the group's total, written by ScanInBucketExclusive.
RWStructuredBuffer<uint> _BlockSum;
// Threads whose dispatch thread id is >= count neither read _Input[DTid] nor
// write _Result[DTid].
uint count;

// Group-shared scratch array, one slot per thread in the group, used by
// ScanInBucketExclusive to hold the partial sums of the scan.
groupshared uint bucket[THREADS_PER_GROUP];

// Exclusive prefix sum within each thread group (up-sweep / down-sweep over
// the group-shared 'bucket' array).
//
//   DTid : dispatch thread id; index of the element this thread loads from
//          _Input and stores to _Result.
//   Gid  : thread-group id; index of the _BlockSum slot for this group.
//   GI   : index of the thread inside its group; index into 'bucket'.
//
// Outputs:
//   _Result[DTid] : sum of the group's elements that precede this one
//                   (written only when DTid < count).
//   _BlockSum[Gid]: sum of all elements of the group.
[numthreads(THREADS_PER_GROUP, 1, 1)]
void ScanInBucketExclusive(uint DTid : SV_DispatchThreadID, uint Gid : SV_GroupID, uint GI : SV_GroupIndex)
{
    // Threads past the end of the data contribute 0 so they do not alter the
    // sums. Kept as if/else (not a ternary) so _Input is never indexed out of
    // range.
    if (DTid < count) {
        bucket[GI] = _Input[DTid];
    } else {
        bucket[GI] = 0;
    }

    // No barrier is needed here: every up-sweep iteration (including the
    // first) starts with one, which makes the loads above visible.

    uint stride;

    // up-sweep: after the pass with a given stride, every slot whose
    // (index + 1) is a multiple of that stride holds the sum of the 'stride'
    // slots ending at it.
    [unroll]
    for (stride = 2; stride <= THREADS_PER_GROUP; stride <<= 1)
    {
        // Make the writes of the previous pass (or the initial loads) visible.
        GroupMemoryBarrierWithGroupSync();
        if (((GI + 1) % stride) == 0) 
        {
            const uint half_stride = (stride >> 1);
            bucket[GI] += bucket[GI - half_stride];
        }
    }

    // Conservative separation between the up-sweep and the root extraction.
    // The last thread only reads the slot it wrote itself in the final
    // up-sweep pass; the clear below is published to the other threads by the
    // barrier at the top of the first down-sweep iteration.
    GroupMemoryBarrierWithGroupSync();
    if (GI == THREADS_PER_GROUP - 1)
    {
        // The last slot now holds the group total: export it, then clear the
        // slot so the down-sweep produces an exclusive scan.
        _BlockSum[Gid] = bucket[GI];
        bucket[GI] = 0;
    }

    // down-sweep: each active slot hands its current value to its left
    // partner and keeps (its value + the partner's previous value).
    [unroll]
    for (stride = THREADS_PER_GROUP; stride > 1; stride >>= 1)
    {
        GroupMemoryBarrierWithGroupSync();

        if (((GI + 1) % stride) == 0) 
        {
            const uint half_stride = (stride >> 1);
            const uint prev_idx = GI - half_stride;
            // uint, matching the element type of 'bucket' (was int, which
            // round-tripped the unsigned partial sum through a signed type).
            const uint tmp = bucket[prev_idx];
            bucket[prev_idx] = bucket[GI];
            bucket[GI] += tmp;
        }
    }

    // Wait for the final down-sweep pass before reading the results.
    GroupMemoryBarrierWithGroupSync();

    if (DTid < count) 
        _Result[DTid] = bucket[GI];
}

// Adds the per-group value _Input[Gid] to every in-range element of that
// group in _Result, turning the per-bucket scan into the final result.
//
//   DTid : dispatch thread id; index of the _Result element to update.
//   Gid  : thread-group id; index of the value in _Input to add.
[numthreads(THREADS_PER_GROUP, 1, 1)]
void ScanAddBucketResult(uint DTid : SV_DispatchThreadID, uint Gid : SV_GroupID)
{
    // Threads beyond the element count have nothing to update.
    if (DTid >= count)
    {
        return;
    }

    const uint groupOffset = _Input[Gid];
    _Result[DTid] = _Result[DTid] + groupOffset;
}
